from langchain_core.tools import tool
from core.testAgent.CKGRetriever import CKGRetriever
from typing import List, Tuple, Any, Union, Optional
import contextvars

from core.testAgent.Node import Clazz
from core.testAgent.HumanAssistance import get_human_assistance_manager
from core.testAgent.ProgressTracker import progress_tracker

# CKGRetriever is produced as a singleton: only one instance exists globally,
# and every tool in this module queries the knowledge graph through it.
# NOTE(review): the Neo4j URI and credentials are hard-coded and the connection is
# created at import time; consider loading them from configuration/environment.
graph_retriever = CKGRetriever("bolt://localhost:7687", "neo4j", "12345678")

# Job id bound to the current execution context (default: None). It is set by
# `bind_tool_job_context` and read by `human_assistance` to attribute progress
# events to a job.
_job_context = contextvars.ContextVar("retriever_job_id", default=None)


def bind_tool_job_context(job_id: Optional[str]) -> contextvars.Token:
    """Bind ``job_id`` to the current execution context for the tools in this module.

    ``human_assistance`` reads the bound id to report assistance requests and
    timeouts to the progress tracker; when the bound value is ``None`` (or any
    falsy value) nothing is reported.

    Args:
        job_id: The job identifier to bind, or ``None`` for "no job".

    Returns:
        The ``contextvars.Token`` produced by ``ContextVar.set``. Pass it to
        ``_job_context.reset(token)`` to restore the previous binding. Callers
        that ignore the return value behave exactly as before.
    """
    return _job_context.set(job_id)


@tool
def find_variable_definition(variable_name: str, full_qualified_name: Optional[str] = None) -> Union[str, list[dict[str, Any]]]:
    """
    Finds up to three variable definitions for a given name in the code knowledge graph.

    Args:
        variable_name (str): The variable name to search.
        full_qualified_name (str, optional): Fully qualified name (e.g., namespace or class).

    Returns:
        list: Up to three dictionaries with keys:
              - "name" (str): Variable name.
              - "full_qualified_name" (str): Fully qualified name.
              - "type" (str): Data type.
              - "modifiers" (str): Modifiers.
              - "class_name" (str): Class name.
              - "summarization" (str, optional): Summarization of the variable, when available.
              - "content" (str, optional): Variable source content, given only when there is no summarization.
        str: A hint message when no matching variable is found.

    Example Usage:
        - find_variable_definition("result", "com.example.Calculator.result")
        - find_variable_definition("count")
    """
    # NOTE: this docstring doubles as the tool description shown to the agent.
    ret = graph_retriever.search_variable_query(variable_name, full_qualified_name)
    if ret == "NO MATCH FOR SOURCE VARIABLE":
        return (
            f"No Variable node found for the given variable name: {variable_name}, full qualified name: {full_qualified_name}. "
            f"If you are unsure whether it is a method, variable or class, use `fuzzy_search`.")
    if isinstance(ret, str):
        # Any other status string from the retriever is surfaced to the agent
        # verbatim instead of tripping an assert (which `python -O` would strip).
        return ret

    results = []
    for node in ret[:3]:
        entry = {"name": node.name,
                 "full_qualified_name": node.full_qualified_name,
                 "type": node.data_type,
                 "modifiers": node.modifiers,
                 "class_name": node.class_name}
        # Prefer the compact summarization; fall back to raw content.
        if node.summarization:
            entry["summarization"] = node.summarization
        else:
            entry["content"] = node.content
        results.append(entry)
    return results


# The agent sometimes passes only two arguments (e.g. method_name and full_qualified_name);
# the underlying query function was updated to support such two-argument queries.
@tool
def find_method_definition(method_name: str, full_qualified_name: Optional[str] = None, params: Optional[List[str]] = None) -> Union[str, list[dict[str, Any]]]:
    """
    Finds method definitions in the code knowledge graph based on the given parameters.

    Args:
        method_name (str): Mandatory. The name of the method to search.
        full_qualified_name (str, optional): Fully qualified name (e.g., namespace or class).
        params (list, optional): A list of parameter types.

    Returns:
        List[dict]: Up to three dictionaries with keys:
            - "name" (str): Method name.
            - "full_qualified_name" (str): Fully qualified name.
            - "params" (list): Parameters.
            - "return_type" (str): Return type.
            - "summarization" (str, optional): Summarization of the method, when available.
            - "content" (str, optional): Method body, given only when there is no summarization.
        str: A hint message when no matching method is found.

    Functionality:
        - Exact match is performed if all three parameters are provided.
        - If only `method_name` is provided, a fuzzy search is performed.
        - At most three results are returned.

    Example Usage:
        - find_method_definition("add")
        - find_method_definition("add", "com.example.Calculator.add")
        - find_method_definition("add", "com.example.Calculator.add", ["int", "int"])
    """
    # NOTE: this docstring doubles as the tool description shown to the agent.
    ret = graph_retriever.search_method_query(method_name, full_qualified_name, params)
    if ret == "NO MATCH FOR SOURCE METHOD":
        return (
            f"No Method node found for the given method name: {method_name}, full qualified name: {full_qualified_name}, and params: {params}. "
            f"If you are unsure whether it is a method, variable or class, use `fuzzy_search`.")
    if isinstance(ret, str):
        # Surface any other retriever status string instead of asserting on it.
        return ret

    results = []
    for node in ret[:3]:
        entry = {"name": node.name,
                 "full_qualified_name": node.full_qualified_name,
                 "params": node.params,
                 "return_type": node.return_type}
        # Prefer the compact summarization; fall back to the raw method body.
        if node.summarization:
            entry["summarization"] = node.summarization
        else:
            entry["content"] = node.content
        results.append(entry)
    return results


@tool
def find_class(class_name: str, full_qualified_name: Optional[str] = None) -> Union[str, dict[str, Any]]:
    """
    Finds the most relevant class node for a given class name in the code knowledge graph.

    Args:
        class_name (str): The class name to search.
        full_qualified_name (str, optional): Fully qualified name.

    Returns:
        dict: A dictionary with the following keys:
            - "name" (str): Class name.
            - "full_qualified_name" (str): Fully qualified name.
            - "modifiers" (str): Modifiers.
            - "parent_classes" (List[str]): Parent classes.
            - "sub_classes" (List[str]): Sub classes.
            - "summarization" (str, optional): Summarization of the class, when available.
            - "simple_content" (str, optional): Simplified class content, given only when there is no summarization.
            - "comment" (str, optional): Class comment, given only when there is no summarization.
        str: A hint message when no matching class is found.

    Example Usage:
        - find_class("Calculator", "com.example.Calculator")
        - find_class("MyClass")
    """
    # NOTE: this docstring doubles as the tool description shown to the agent.
    node = graph_retriever.search_clazz_query(class_name, full_qualified_name)
    if isinstance(node, str):
        # The retriever reports lookup failures as status strings.
        return (f"{node}: {class_name}.\n"
                f"If you are unsure whether it is a method, variable or class, use `fuzzy_search`.")
    if not isinstance(node, Clazz):
        # Fail fast, independent of `python -O`, before issuing follow-up queries.
        raise TypeError(f"search_clazz_query returned {type(node).__name__}, expected Clazz")

    # Hierarchy lookups are keyed by the simple class name only.
    # NOTE(review): classes sharing a simple name in different packages may mix
    # hierarchies here — confirm whether the retriever disambiguates.
    parent_classes = graph_retriever.search_parent_clazz_query(class_name)
    sub_classes = graph_retriever.search_sub_clazz_query(class_name)

    # Any status string (e.g. "NO MATCH FOR PARENT CLAZZ") or an empty/None result
    # means "none"; iterating a string would otherwise fail on `.name`.
    parent_names = ([] if isinstance(parent_classes, str) or not parent_classes
                    else [parent.name for parent in parent_classes])
    sub_names = ([] if isinstance(sub_classes, str) or not sub_classes
                 else [sub.name for sub in sub_classes])

    ret = {"name": node.name,
           "full_qualified_name": node.full_qualified_name,
           "modifiers": node.modifiers,
           "parent_classes": parent_names,
           "sub_classes": sub_names}
    # Prefer the compact summarization; fall back to the simplified source and comment.
    if node.summarization:
        ret["summarization"] = node.summarization
    else:
        ret["simple_content"] = node.simple_content
        ret["comment"] = node.comment
    return ret


@tool
def find_method_calls(method_name: str, full_qualified_name: str, method_params: List[str]) -> Union[str, list[dict[str, Any]]]:
    """
    Finds occurrences of a given method being called in the code knowledge graph.

    Args:
        method_name (str): Name of the method to search.
        full_qualified_name (str): Fully qualified name of the method (e.g., class and namespace).
        method_params (List[str]): Parameter types to distinguish overloaded methods.

    Returns:
        List[dict]: A list of dictionaries with keys:
            - "name" (str): Method name.
            - "full_qualified_name" (str): Fully qualified name.
            - "params" (List[str], optional): Parameters, when non-empty.
            - "return_type" (str, optional): Return type, when available.
            - "summarization" (str, optional): Summarization of the method, when available.
        str: A hint message when the method or its calls cannot be found.

    Example Usage:
        - find_method_calls("sub", "com.example.Calculator.sub", ["int", "int"])
        - find_method_calls("toString", "com.example.MyClass.toString", [])
    """
    # NOTE: this docstring doubles as the tool description shown to the agent.
    ret = graph_retriever.method_calls_query(method_name, full_qualified_name, method_params)
    if ret == "NO MATCH FOR SOURCE METHOD":
        return (
            f"No Method node found for the given method name: {method_name}, full qualified name: {full_qualified_name}, and params: {method_params}. "
            f"Try to use `find_method_definition` to find the method first.")
    if ret == "MULTIPLE MATCHES FOR SOURCE METHOD":
        return (
            f"Multiple Method nodes found for the given method name: {method_name}, full qualified name: {full_qualified_name}, and params: {method_params}. "
            f"Try to use `find_method_definition` to find the method first.")
    if ret == "NO MATCH FOR TARGET METHOD":
        return f"No Method calls found for the given method name: {method_name}, full qualified name: {full_qualified_name}, and params: {method_params}."
    if isinstance(ret, str):
        # Surface any other retriever status string instead of asserting on it.
        return ret

    results = []
    for item in ret:
        entry = {"name": item.name,
                 "full_qualified_name": item.full_qualified_name}
        # Not every result node is guaranteed to carry these attributes; only
        # include the ones that exist and are non-empty.
        for key in ("params", "return_type", "summarization"):
            value = getattr(item, key, None)
            if value:
                entry[key] = value
        results.append(entry)
    return results


@tool
def find_method_usages(method_name: str, full_qualified_name: str, method_params: List[str]) -> Union[str, list[dict[str, Any]]]:
    """
    Finds usages of a given method in the code knowledge graph.

    Args:
        method_name (str): Name of the method to search.
        full_qualified_name (str): Fully qualified name of the method (e.g., class and namespace).
        method_params (List[str]): Parameter types to distinguish overloaded methods.

    Returns:
        List[dict]: A list of dictionaries with keys:
            - "node_type" (str): Type of the node (e.g., "Method", "Clazz", "Variable").
            - "name" (str): Node name.
            - "full_qualified_name" (str): Fully qualified name.
            - "params" (List[str], optional): Parameters, for nodes that have them.
            - "return_type" (str, optional): Return type, for nodes that have one.
            - "summarization" (str, optional): Summarization of the node, when available.
        str: A hint message when the method or its usages cannot be found.

    Example Usage:
        - find_method_usages("add", "com.example.Calculator.add", ["int", "int"])
        - find_method_usages("toString", "com.example.MyClass.toString", [])
    """
    # NOTE: this docstring doubles as the tool description shown to the agent.
    ret = graph_retriever.method_usages_query(method_name, full_qualified_name, method_params)
    if ret == "NO MATCH FOR SOURCE METHOD":
        return (
            f"No Method node found for the given method name: {method_name}, full qualified name: {full_qualified_name}, and params: {method_params}. "
            f"Try to use `find_method_definition` to find the method first.")
    if ret == "MULTIPLE MATCHES FOR SOURCE METHOD":
        return (
            f"Multiple Method nodes found for the given method name: {method_name}, full qualified name: {full_qualified_name}, and params: {method_params}. "
            f"Try to use `find_method_definition` to find the method first.")
    if ret == "NO MATCH FOR TARGET METHOD":
        return f"No Method usages found for the given method name: {method_name}, full qualified name: {full_qualified_name}, and params: {method_params}."
    if isinstance(ret, str):
        # Surface any other retriever status string instead of asserting on it.
        return ret

    results = []
    for item in ret:
        entry = {"node_type": type(item).__name__, "name": item.name,
                 "full_qualified_name": item.full_qualified_name}
        # Usages may be non-method nodes that lack these attributes. Keys are kept
        # whenever the attribute exists (even if empty), matching the previous
        # output for method nodes, and skipped otherwise instead of raising.
        for key in ("params", "return_type"):
            if hasattr(item, key):
                entry[key] = getattr(item, key)
        summarization = getattr(item, "summarization", None)
        if summarization:
            entry["summarization"] = summarization
        results.append(entry)
    return results


@tool
def fuzzy_search(name: str) -> Union[str, list[dict[str, str]]]:
    """
    Fuzzy search for a given name in the code knowledge graph.
    It is recommended to use this function when you first search for a member in the code knowledge graph.

    Args:
        name (str): The name to search.

    Returns:
        List[dict]: Up to 20 dictionaries with keys:
            - "name" (str): Name of the member.
            - "type" (str): Type of the member (e.g., "Method", "Variable", "Clazz").
            - "full_qualified_name" (str): Fully qualified name.
        str: A hint message when nothing matches.
    """
    # NOTE: this docstring doubles as the tool description shown to the agent.
    from itertools import islice

    nodes = graph_retriever.fuzzy_search(name)
    # A status string, None, or an empty result all mean "no match".
    if isinstance(nodes, str) or not nodes:
        return f"No nodes found for the given name: {name}."
    # Only the first 20 matches are reported; stop before building the rest.
    return [{"name": node.name,
             "type": type(node).__name__,
             "full_qualified_name": node.full_qualified_name}
            for node in islice(nodes, 20)]


@tool
def search_similarity_test_class(test_class_name: str) -> Union[str, list[dict[str, Any]]]:
    """
    Finds up to three classes most similar to a given test class name in the code knowledge graph.

    Args:
        test_class_name (str): The test class name to search.

    Returns:
        List[dict]: Up to three dictionaries with keys:
            - "name" (str): Relevant class name.
            - "content" (str): Content of the class.
            - "similarity" (float): Similarity score, rounded to two decimals.
        str: A hint message when nothing relevant is found.

    """
    # NOTE: this docstring doubles as the tool description shown to the agent.
    matches = graph_retriever.search_similarity_test_class(test_class_name)
    if matches is None:
        return f"No nodes found for the given name: {test_class_name}."
    if len(matches) == 0:
        return f"No relevant nodes found for the given name: {test_class_name}."

    results = []
    # Each match is indexed as (class_node, score); only the top three are reported.
    for match in matches[:3]:
        clazz, score = match[0], match[1]
        results.append({"name": clazz.name, "content": clazz.content, "similarity": round(score, 2)})
    return results

# @tool
# def find_import_statements():
#     """pass"""
#     pass
# a = graph_retriever.fuzzy_search("INSTANCE")
# print(a)
#
# b = fuzzy_search("INSTANCE")
# print(b)


@tool
def human_assistance(prompt: str, context: str = "", artifact: str = "") -> dict:
    """
    Ask a human collaborator for help when the agent is uncertain.

    Args:
        prompt (str): The question for the human (e.g., confirm the test purpose, validate generated test cases).
            Must be non-empty and at most 200 characters after trimming whitespace.
        context (str, optional): Extra background shown to the human.
        artifact (str, optional): Partial output the human should review.

    Returns:
        dict: The assistance result (including the human-provided content) when the human answered;
            otherwise a dict with "status" (e.g. "error" or "timeout") and "message".
    """
    # NOTE: this docstring doubles as the tool description shown to the agent.
    question = (prompt or "").strip()

    # Validate the prompt first; both failures are reported as status "error".
    rejection = None
    if not question:
        rejection = "Human assistance prompt 不能为空，请描述需要帮助的问题。"
    elif len(question) > 200:
        rejection = "Human assistance 请求过长，请压缩至 200 个字符以内后再试。"
    if rejection is not None:
        return {"status": "error", "message": rejection}

    background = (context or "").strip()
    helper = get_human_assistance_manager()
    current_job = _job_context.get()

    def on_request(req):
        # Report the new request to the progress tracker only when a job is bound.
        if current_job:
            progress_tracker.record_assistance_request(current_job, req["id"], question, background)

    outcome = helper.request_help(question, context=background, artifact=artifact, on_request=on_request)
    if outcome.get("status") != "answered":
        # No answer: log it for the bound job and hand back a compact status dict.
        if current_job:
            progress_tracker.record_message(current_job, "Human assistance 未在规定时间内回复，继续自动尝试。")
        return {
            "status": outcome.get("status", "timeout"),
            "message": outcome.get("message", "No human response received within timeout.")
        }
    return outcome
